
Add restart #106

Merged
chengcli merged 5 commits into main from cli/restart on Jan 18, 2026

Conversation

@chengcli
Owner

No description provided.

Copilot AI review requested due to automatic review settings January 17, 2026 22:38
Contributor

Copilot AI left a comment


Pull request overview

This PR adds restart functionality to the simulation framework, allowing simulations to resume from previously saved state files. The implementation supports loading restart data from both single .part TorchScript files and tar archives containing multiple rank-specific .part files.

Changes:

  • Added restart file loading infrastructure with support for tar archives and single .part files
  • Modified MeshBlock initialization to accept an optional restart file parameter
  • Updated CMake configuration to include libarchive dependency (v3.8.5)
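
To make the resulting usage concrete, here is a minimal caller-side sketch. It assumes only the Variables alias and the load_restart signature that appear in the diff further down; the main() wrapper and the restart file name are illustrative, not taken from this PR.

// Sketch only: resume bookkeeping from a restart file.
// "case.restart.tar" is a hypothetical file name.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

#include <torch/torch.h>

using Variables = std::map<std::string, torch::Tensor>;

namespace snap {
void load_restart(Variables& vars, std::string const& path);  // added by this PR
}

int main() {
  Variables vars;
  // Dispatch between a tar archive of rank-specific .part files and a single
  // .part TorchScript file happens inside load_restart.
  snap::load_restart(vars, "case.restart.tar");

  if (vars.count("last_cycle")) {
    std::cout << "resuming from cycle "
              << vars.at("last_cycle").item<int64_t>() << "\n";
  }
  return 0;
}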

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 8 comments.

File summary (per-file changes):

  • src/mesh/meshblock.hpp: added a restart_file parameter to the initialize method; updated the _init_from_restart signature to accept a filename
  • src/mesh/meshblock.cpp: refactored restart initialization to use the new load_restart function; moved tensor device transfer and timing-data cleanup into _init_from_restart
  • src/input/read_restart_file.hpp: simplified the interface; removed the old helper function and added a clean load_restart API
  • src/input/read_restart_file.cpp: new implementation with tar-archive support using libarchive; handles rank-based filtering of restart files
  • src/CMakeLists.txt: added the libarchive include directory and library linkage
  • cmake/libarchive.cmake: new CMake configuration for fetching and building libarchive v3.8.5
  • pyproject.toml: updated the kintera dependency from >=1.2.6 to >=1.2.9
  • python/csrc/snapy.cpp: added a Python binding for the load_restart function
  • python/csrc/pymesh.cpp: modified the initialize binding to accept an optional restart_file parameter
  • examples/straka.cpp: updated to use the new initialize API with a restart parameter
  • examples/run_hydro.cpp: updated to use the new initialize API with a CLI restart filename


found_part = true;
auto out = parse_part_filename(name);
// find the block rank number after "block"
int rank = std::stoi(out.blockid.substr(5, out.blockid.size() - 5));

Copilot AI Jan 17, 2026


The code assumes blockid starts with "block" (5 characters) and directly uses substr(5, ...) without validation. If the blockid doesn't follow this expected format (e.g., is shorter than 5 characters), this will cause undefined behavior or throw an out_of_range exception. Consider adding validation to check if blockid starts with "block" before extracting the rank number.
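
A sketch of the validation being suggested, not part of the PR: check for the "block" prefix before extracting the rank so a malformed blockid fails loudly instead of risking undefined behavior. The helper name parse_block_rank is hypothetical.

#include <stdexcept>
#include <string>
#include <string_view>

// Hypothetical helper illustrating the suggested check; not in the PR.
static int parse_block_rank(const std::string& blockid) {
  constexpr std::string_view prefix = "block";
  if (blockid.size() <= prefix.size() ||
      blockid.compare(0, prefix.size(), prefix) != 0) {
    throw std::invalid_argument("unexpected blockid format: " + blockid);
  }
  // std::stoi still throws std::invalid_argument if the remainder is not a
  // number, so malformed names do not silently map to a bogus rank.
  return std::stoi(blockid.substr(prefix.size()));
}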

Comment on lines +72 to +92
static std::string dtype_to_string(const at::ScalarType t) {
  // at::toString exists in many builds; this is safe enough.
  return std::string(at::toString(t));
}

static std::string device_to_string(const at::Device& d) {
  std::ostringstream oss;
  oss << d;
  return oss.str();
}

static std::string shape_to_string(const at::Tensor& t) {
  std::ostringstream oss;
  oss << "(";
  for (int64_t i = 0; i < t.dim(); ++i) {
    oss << t.size(i);
    if (i + 1 < t.dim()) oss << ", ";
  }
  oss << ")";
  return oss.str();
}

Copilot AI Jan 17, 2026


These helper functions (dtype_to_string, device_to_string, shape_to_string) are defined but never used in this file. Consider removing them if they're not needed, or add a comment explaining that they're intended for future use or debugging purposes.

    archive_read_free(ar);
  } else {
    // Treat as a single .part TorchScript file
    std::cout << "single .part file detected\n";

Copilot AI Jan 17, 2026


This debug output message should either be removed or changed to use a proper logging mechanism instead of directly writing to stdout. Direct console output in library code can interfere with application-level output formatting and logging.

Suggested change
std::cout << "single .part file detected\n";

Comment on lines +205 to +272
void load_restart(Variables& vars, std::string const& path) {
  // Dispatch based on whether `path` is a .part file or a tar archive.
  if (is_tar_archive(path)) {
    struct archive* ar = archive_read_new();
    if (!ar) {
      std::cerr << path << ": failed to allocate archive reader\n";
      return;
    }

    archive_read_support_filter_all(ar);
    archive_read_support_format_all(ar);

    int r = archive_read_open_filename(ar, path.c_str(), 10240);
    if (r != ARCHIVE_OK) {
      std::cerr << path
                << ": failed to open archive: " << archive_error_string(ar)
                << "\n";
      archive_read_free(ar);
      return;
    }

    bool found_part = false;

    struct archive_entry* entry = nullptr;
    while ((r = archive_read_next_header(ar, &entry)) == ARCHIVE_OK) {
      const char* name_c = archive_entry_pathname(entry);
      std::string name = name_c ? std::string{name_c} : std::string{};

      if (ends_with(name, ".part")) {
        found_part = true;
        auto out = parse_part_filename(name);
        // find the block rank number after "block"
        int rank = std::stoi(out.blockid.substr(5, out.blockid.size() - 5));
        int my_rank = get_rank();
        if (rank != my_rank) {
          // Not for this rank; skip
          archive_read_data_skip(ar);
        } else {
          load_pt_from_tar(vars, ar, entry);
          return;
        }

        // Note: consume the entry data (via archive_read_data*)
        // or skip it, otherwise the next header read will misbehave.
      } else {
        // Skip non-.part entries quickly
        archive_read_data_skip(ar);
      }
    }

    if (!found_part) {
      std::cerr << path << ": no .part files found in tar archive\n";
    }

    if (r != ARCHIVE_EOF && r != ARCHIVE_OK) {
      std::cerr << path
                << ": error while reading archive: " << archive_error_string(ar)
                << "\n";
    }

    archive_read_close(ar);
    archive_read_free(ar);
  } else {
    // Treat as a single .part TorchScript file
    std::cout << "single .part file detected\n";
    kintera::load_tensors(vars, path);
  }
}

Copilot AI Jan 17, 2026


The load_restart function silently returns on multiple error conditions (allocation failure, file open failure, no .part files found, read errors), leaving vars potentially empty or partially populated. This makes it difficult for callers to distinguish between success and failure. Consider throwing an exception or returning a status code to indicate errors, so that the caller can handle restart failures appropriately.
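
One possible shape for the error propagation suggested here, assuming callers are prepared to catch exceptions; the open_restart_archive helper is hypothetical and only covers the allocation and open failures.

#include <stdexcept>
#include <string>

#include <archive.h>

// Hypothetical helper: fail by throwing instead of printing and returning.
static struct archive* open_restart_archive(std::string const& path) {
  struct archive* ar = archive_read_new();
  if (!ar) {
    throw std::runtime_error(path + ": failed to allocate archive reader");
  }
  archive_read_support_filter_all(ar);
  archive_read_support_format_all(ar);
  if (archive_read_open_filename(ar, path.c_str(), 10240) != ARCHIVE_OK) {
    const char* err = archive_error_string(ar);
    std::string msg =
        path + ": failed to open archive: " + (err ? err : "unknown error");
    archive_read_free(ar);
    throw std::runtime_error(msg);
  }
  return ar;  // caller owns the reader and must archive_read_free() it
}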

load_restart(vars, fname);

cycle = timing_vars.at("last_cycle").item<int64_t>();
cycle = vars.at("last_cycle").item<int64_t>() - 1;

Copilot AI Jan 17, 2026


The cycle is set to last_cycle - 1, which seems counterintuitive. If the restart file contains the last completed cycle, the next cycle should typically be last_cycle + 1, not last_cycle - 1. This decrement could cause the simulation to repeat a cycle or have an incorrect cycle count. Please verify this is the intended behavior and add a comment explaining the reasoning if this is correct.

Suggested change
cycle = vars.at("last_cycle").item<int64_t>() - 1;
// last_cycle in the restart file stores the index of the last completed cycle.
// Set cycle to the next cycle to run so that the simulation continues forward.
cycle = vars.at("last_cycle").item<int64_t>() + 1;

Comment on lines +1 to +274
// C/C++
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <optional>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

// archive
#include <archive.h>
#include <archive_entry.h>

// kintera
#include <kintera/utils/serialize.hpp>

// snap
#include <snap/layout/layout.hpp>

namespace fs = std::filesystem;

namespace snap {

// -------------------------
// Small helpers
// -------------------------

struct RestartFields {
  std::string basename;
  std::string blockid;
  std::string filenumber;
};

static RestartFields parse_part_filename(const std::string& name) {
  constexpr std::string_view suffix = ".part";

  if (name.size() <= suffix.size() ||
      name.compare(name.size() - suffix.size(), suffix.size(), suffix) != 0) {
    throw std::invalid_argument("filename does not end with .part");
  }

  // Strip ".part"
  const std::string_view core(name.data(), name.size() - suffix.size());

  // Find last two dots
  const size_t dot2 = core.rfind('.');
  if (dot2 == std::string::npos) {
    throw std::invalid_argument("filename missing filenumber field");
  }

  const size_t dot1 = core.rfind('.', dot2 - 1);
  if (dot1 == std::string::npos) {
    throw std::invalid_argument("filename missing block_id field");
  }

  RestartFields out;
  out.basename = std::string(core.substr(0, dot1));
  out.blockid = std::string(core.substr(dot1 + 1, dot2 - dot1 - 1));
  out.filenumber = std::string(core.substr(dot2 + 1));

  if (out.basename.empty() || out.blockid.empty() || out.filenumber.empty()) {
    throw std::invalid_argument("one or more filename fields are empty");
  }

  return out;
}

static std::string dtype_to_string(const at::ScalarType t) {
  // at::toString exists in many builds; this is safe enough.
  return std::string(at::toString(t));
}

static std::string device_to_string(const at::Device& d) {
  std::ostringstream oss;
  oss << d;
  return oss.str();
}

static std::string shape_to_string(const at::Tensor& t) {
  std::ostringstream oss;
  oss << "(";
  for (int64_t i = 0; i < t.dim(); ++i) {
    oss << t.size(i);
    if (i + 1 < t.dim()) oss << ", ";
  }
  oss << ")";
  return oss.str();
}

static bool ends_with(std::string const& s, std::string const& suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

static bool is_tar_archive(std::string const& path) {
  if (!fs::is_regular_file(path)) return false;

  struct archive* ar = archive_read_new();
  if (!ar) return false;

  archive_read_support_filter_all(ar);
  archive_read_support_format_all(ar);

  // Try opening as an archive; if it succeeds, treat as tar-like.
  int r = archive_read_open_filename(ar, path.c_str(), 10240);
  if (r != ARCHIVE_OK) {
    archive_read_free(ar);
    return false;
  }

  // Some files might be recognized as other archive formats too; in practice
  // this matches Python's "is_tarfile" intent well.
  archive_read_close(ar);
  archive_read_free(ar);
  return true;
}

// Create a unique temp file path; not bulletproof, but good enough
static fs::path make_temp_path(std::string_view suffix) {
  fs::path dir = fs::temp_directory_path();

  std::random_device rd;
  std::mt19937_64 gen(rd());
  std::uniform_int_distribution<uint64_t> dis;

  for (int tries = 0; tries < 20; ++tries) {
    uint64_t r = dis(gen);
    std::ostringstream name;
    name << "tmp_" << std::hex << r << suffix;
    fs::path p = dir / name.str();
    if (!fs::exists(p)) return p;
  }

  // Fallback (very unlikely to collide)
  return dir / ("tmp_fallback" + std::string(suffix));
}

static void load_pt_from_tar(Variables& vars, struct archive* ar,
                             struct archive_entry* entry) {
  const char* name_c = archive_entry_pathname(entry);
  std::string member_name =
      name_c ? std::string{name_c} : std::string{"<unknown>"};

  // Extract this entry into a temporary file (TorchScript loader prefers
  // real/seekable file)
  fs::path tmp_path = make_temp_path(".part");
  std::ofstream out(tmp_path, std::ios::binary);
  if (!out) {
    std::cerr << "\n=== " << member_name << " ===\n";
    std::cerr << " ERROR: could not create temp file: " << tmp_path.string()
              << "\n";
    // Must still consume/skip entry data:
    archive_read_data_skip(ar);
    return;
  }

  std::vector<char> buf(1 << 20);
  while (true) {
    la_ssize_t n = archive_read_data(ar, buf.data(), buf.size());
    if (n == 0) break;  // end of this entry
    if (n < 0) {
      std::cerr << "\n=== " << member_name << " ===\n";
      std::cerr << " ERROR: could not extract file from tar: "
                << archive_error_string(ar) << "\n";
      out.close();
      std::error_code ec;
      fs::remove(tmp_path, ec);
      return;
    }
    out.write(buf.data(), static_cast<std::streamsize>(n));
    if (!out) {
      std::cerr << "\n=== " << member_name << " ===\n";
      std::cerr << " ERROR: failed writing temp file\n";
      out.close();
      std::error_code ec;
      fs::remove(tmp_path, ec);
      return;
    }
  }

  out.flush();
  out.close();

  // load the extracted .part
  kintera::load_tensors(vars, tmp_path.string());

  // remove empty tensors (if any)
  for (auto it = vars.begin(); it != vars.end();) {
    if (!it->second.defined() || it->second.numel() == 0) {
      it = vars.erase(it);
    } else {
      ++it;
    }
  }

  // Cleanup
  std::error_code ec;
  fs::remove(tmp_path, ec);
}

void load_restart(Variables& vars, std::string const& path) {
  // Dispatch based on whether `path` is a .part file or a tar archive.
  if (is_tar_archive(path)) {
    struct archive* ar = archive_read_new();
    if (!ar) {
      std::cerr << path << ": failed to allocate archive reader\n";
      return;
    }

    archive_read_support_filter_all(ar);
    archive_read_support_format_all(ar);

    int r = archive_read_open_filename(ar, path.c_str(), 10240);
    if (r != ARCHIVE_OK) {
      std::cerr << path
                << ": failed to open archive: " << archive_error_string(ar)
                << "\n";
      archive_read_free(ar);
      return;
    }

    bool found_part = false;

    struct archive_entry* entry = nullptr;
    while ((r = archive_read_next_header(ar, &entry)) == ARCHIVE_OK) {
      const char* name_c = archive_entry_pathname(entry);
      std::string name = name_c ? std::string{name_c} : std::string{};

      if (ends_with(name, ".part")) {
        found_part = true;
        auto out = parse_part_filename(name);
        // find the block rank number after "block"
        int rank = std::stoi(out.blockid.substr(5, out.blockid.size() - 5));
        int my_rank = get_rank();
        if (rank != my_rank) {
          // Not for this rank; skip
          archive_read_data_skip(ar);
        } else {
          load_pt_from_tar(vars, ar, entry);
          return;
        }

        // Note: consume the entry data (via archive_read_data*)
        // or skip it, otherwise the next header read will misbehave.
      } else {
        // Skip non-.part entries quickly
        archive_read_data_skip(ar);
      }
    }

    if (!found_part) {
      std::cerr << path << ": no .part files found in tar archive\n";
    }

    if (r != ARCHIVE_EOF && r != ARCHIVE_OK) {
      std::cerr << path
                << ": error while reading archive: " << archive_error_string(ar)
                << "\n";
    }

    archive_read_close(ar);
    archive_read_free(ar);
  } else {
    // Treat as a single .part TorchScript file
    std::cout << "single .part file detected\n";
    kintera::load_tensors(vars, path);
  }
}

}  // namespace snap

Copilot AI Jan 17, 2026


The new restart functionality lacks test coverage. Consider adding tests that verify: (1) loading restart files from tar archives, (2) loading single .part files, (3) correct extraction of timing data (last_time, last_cycle, file_number, next_time), (4) proper handling of missing files or corrupted archives, and (5) correct rank-based filtering of .part files in multi-rank scenarios.
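
As a starting point, here is a sketch of one of the suggested failure-path tests: a tar archive that contains no .part member should leave vars untouched. It uses plain asserts and libarchive's write API; the test file name is a placeholder, and the project may prefer its own test framework.

#include <cassert>
#include <map>
#include <string>

#include <archive.h>
#include <archive_entry.h>
#include <torch/torch.h>

using Variables = std::map<std::string, torch::Tensor>;
namespace snap {
void load_restart(Variables& vars, std::string const& path);
}

int main() {
  const std::string tar_path = "no_part_members.tar";  // hypothetical path
  const char payload[] = "not a restart part";

  // Write a one-member tar whose entry does not end in ".part".
  struct archive* aw = archive_write_new();
  archive_write_set_format_pax_restricted(aw);
  archive_write_open_filename(aw, tar_path.c_str());

  struct archive_entry* e = archive_entry_new();
  archive_entry_set_pathname(e, "readme.txt");
  archive_entry_set_size(e, sizeof(payload) - 1);
  archive_entry_set_filetype(e, AE_IFREG);
  archive_entry_set_perm(e, 0644);
  archive_write_header(aw, e);
  archive_write_data(aw, payload, sizeof(payload) - 1);
  archive_entry_free(e);
  archive_write_close(aw);
  archive_write_free(aw);

  Variables vars;
  snap::load_restart(vars, tar_path);

  // With no .part members, load_restart reports the problem and loads nothing.
  assert(vars.empty());
  return 0;
}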

Comment on lines +122 to +140
// Create a unique temp file path; not bulletproof, but good enough
static fs::path make_temp_path(std::string_view suffix) {
  fs::path dir = fs::temp_directory_path();

  std::random_device rd;
  std::mt19937_64 gen(rd());
  std::uniform_int_distribution<uint64_t> dis;

  for (int tries = 0; tries < 20; ++tries) {
    uint64_t r = dis(gen);
    std::ostringstream name;
    name << "tmp_" << std::hex << r << suffix;
    fs::path p = dir / name.str();
    if (!fs::exists(p)) return p;
  }

  // Fallback (very unlikely to collide)
  return dir / ("tmp_fallback" + std::string(suffix));
}

Copilot AI Jan 17, 2026


The temporary file creation uses a random number generator but doesn't set permissions on the created file. Consider using std::filesystem operations with explicit permissions (e.g., fs::perms::owner_read | fs::perms::owner_write) to prevent other users from accessing potentially sensitive restart data in the temporary directory.
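
A sketch of the tightening being suggested, assuming the make_temp_path helper above; the open_private_temp wrapper is hypothetical and not part of the PR.

#include <filesystem>
#include <fstream>
#include <system_error>

namespace fs = std::filesystem;

// Hypothetical wrapper: create the temp file, then immediately restrict it to
// owner read/write (0600), best effort.
static std::ofstream open_private_temp(const fs::path& p) {
  std::ofstream out(p, std::ios::binary);  // creates the file if possible
  if (out) {
    std::error_code ec;
    fs::permissions(p, fs::perms::owner_read | fs::perms::owner_write,
                    fs::perm_options::replace, ec);
  }
  return out;
}

Note that this narrows but does not close the window between creation and the permission change; on POSIX systems, mkstemp (which creates the file with mode 0600 atomically) would be the stricter alternative.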

hydro_w.index(interior)[IPR] = in_vars["press"];
}
using Variables = std::map<std::string, torch::Tensor>;


Copilot AI Jan 17, 2026


The load_restart function lacks documentation. Consider adding a comment block that describes: (1) the purpose of the function, (2) the expected format of the restart file (tar archive or single .part file), (3) what variables are loaded into the vars map, (4) what happens on error (currently silent failure), and (5) how multi-rank scenarios are handled.

Suggested change
/**
 * Load model/state variables from a restart file into the given map.
 *
 * The restart file at \p path is expected to be either:
 *   - a tar archive containing one or more rank-specific `.part` files, or
 *   - a single `.part` file corresponding to the current rank.
 *
 * On success, \p vars is populated with named tensors deserialized from
 * the restart file. Keys in the map correspond to variable names stored
 * in the restart, and values are the associated torch::Tensor objects.
 *
 * Error handling:
 *   - The function does not throw or return an explicit error code.
 *   - I/O or parsing failures typically result in an empty or partially
 *     populated \p vars map, without additional notification.
 *   Callers are responsible for validating that the expected variables
 *   have been loaded (e.g., by checking for required keys).
 *
 * Multi-rank usage:
 *   - In multi-rank runs, this function is intended to be called
 *     independently by each rank.
 *   - The caller must provide a \p path that resolves to the appropriate
 *     restart data for that rank (e.g., a rank-specific `.part` file
 *     within a tar archive or as a standalone file).
 */

@chengcli merged commit e74442a into main on Jan 18, 2026
1 of 3 checks passed
@chengcli deleted the cli/restart branch on January 18, 2026 at 00:05
@github-actions
Contributor

🎉 Released v1.2.4!

What's Changed

Full Changelog: v1.2.3...v1.2.4

